#!/usr/bin/env python3
"""Convert paper.md to gravity_paper.docx with fully rendered tables."""

import re
from pathlib import Path
from docx import Document
from docx.shared import Inches, Pt, Cm, RGBColor
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.table import WD_TABLE_ALIGNMENT
from docx.oxml.ns import qn, nsdecls
from docx.oxml import parse_xml


# Front matter: an opening '---' line, the YAML body, and a closing '---'
# (or YAML's '...') line; the closing line may also be the last line of text.
_FRONT_MATTER_RE = re.compile(
    r'^---[ \t]*\n(.*?)\n(?:---|\.\.\.)[ \t]*(?:\n|\Z)', re.DOTALL)

# Single-line ``key: value`` fields extracted from the front matter.
_SIMPLE_META_KEYS = ('title', 'author', 'date', 'version', 'keywords', 'jel')


def parse_yaml_front_matter(text):
    """Split a markdown document into its YAML front matter and body.

    This is a deliberately small parser for the fields this script uses, not
    a general YAML reader:

    * ``title``, ``author``, ``date``, ``version``, ``keywords`` and ``jel``
      are read as single-line values with surrounding double quotes removed.
      A key with an empty value yields ``''``.
    * ``abstract`` must be a literal block scalar (``abstract: |``, ``|-`` or
      ``|+``). Its lines are returned with indentation removed (blank lines
      inside the block are collapsed) and outer whitespace stripped.

    Args:
        text: Full markdown source; lines are expected to be ``\\n``-separated.

    Returns:
        ``(meta, body)``: *meta* maps the keys found to strings and *body* is
        the text after the front matter. If *text* does not start with front
        matter, returns ``({}, text)`` unchanged.
    """
    m = _FRONT_MATTER_RE.match(text)
    if not m:
        return {}, text
    yaml_block = m.group(1)
    meta = {}
    for key in _SIMPLE_META_KEYS:
        # [ \t]* rather than \s*: \s also matches '\n', which let an empty
        # value ("date:") run onto the next line and capture that key instead.
        km = re.search(rf'^{key}:[ \t]*"?(.*?)"?[ \t]*$', yaml_block, re.MULTILINE)
        if km:
            meta[key] = km.group(1).strip('"')
    # The literal block runs until the next unindented "key:" line or the end.
    am = re.search(r'abstract:\s*\|[-+]?\s*\n(.*?)(?=\n\w+:|$)', yaml_block, re.DOTALL)
    if am:
        meta['abstract'] = re.sub(r'\n\s{2,}', '\n', am.group(1)).strip()
    return meta, text[m.end():]


def add_formatted_run(paragraph, text, bold=False, italic=False, size=None, color=None, superscript=False):
    """Append *text* to *paragraph* as one run and apply character formatting.

    ``bold`` and ``italic`` are always written to the run. The remaining
    options are applied only when truthy:

    * ``size``: font size in points.
    * ``color``: an ``(r, g, b)`` tuple.
    * ``superscript``: turns superscript on.

    Returns:
        The newly created run.
    """
    new_run = paragraph.add_run(text)
    new_run.bold, new_run.italic = bold, italic
    font = new_run.font
    if size:
        font.size = Pt(size)
    if color:
        font.color.rgb = RGBColor(*color)
    if superscript:
        font.superscript = True
    return new_run


# Inline markdown constructs recognised by add_rich_text, in priority order:
# when two patterns match at the same position the earlier entry wins
# (so '***' beats '**' beats '*', and '$$' beats '$').
_INLINE_PATTERNS = (
    (re.compile(r'\*\*\*(.*?)\*\*\*'), 'bold_italic'),
    (re.compile(r'\*\*(.*?)\*\*'), 'bold'),
    (re.compile(r'\*(.*?)\*'), 'italic'),
    (re.compile(r'\$\$(.*?)\$\$'), 'display_math'),
    (re.compile(r'\$(.*?)\$'), 'math'),
    (re.compile(r'\[(.*?)\]'), 'bracket'),
    (re.compile(r'@(\w+\d{4}\w?)'), 'citation'),
    (re.compile(r'---'), 'emdash'),
)

# LaTeX commands rendered as plain Unicode inside math. Lookup is by whole
# command name, so e.g. ``\in`` can no longer clobber ``\infty`` or ``\int``.
_LATEX_SYMBOLS = {
    # Greek, lower case (``\phi`` keeps the original script's mapping to U+03C6)
    'alpha': '\u03b1', 'beta': '\u03b2', 'gamma': '\u03b3', 'delta': '\u03b4',
    'epsilon': '\u03f5', 'varepsilon': '\u03b5', 'zeta': '\u03b6',
    'eta': '\u03b7', 'theta': '\u03b8', 'kappa': '\u03ba', 'lambda': '\u03bb',
    'mu': '\u03bc', 'nu': '\u03bd', 'xi': '\u03be', 'pi': '\u03c0',
    'rho': '\u03c1', 'sigma': '\u03c3', 'tau': '\u03c4', 'phi': '\u03c6',
    'varphi': '\u03c6', 'chi': '\u03c7', 'psi': '\u03c8', 'omega': '\u03c9',
    # Greek, upper case
    'Gamma': '\u0393', 'Delta': '\u0394', 'Theta': '\u0398', 'Lambda': '\u039b',
    'Pi': '\u03a0', 'Sigma': '\u03a3', 'Phi': '\u03a6', 'Omega': '\u03a9',
    # Operators, relations, arrows
    'sum': '\u2211', 'prod': '\u220f', 'int': '\u222b', 'times': '\u00d7',
    'cdot': '\u00b7', 'cdots': '\u22ef', 'ldots': '\u2026', 'dots': '\u2026',
    'neq': '\u2260', 'ne': '\u2260', 'approx': '\u2248',
    'leq': '\u2264', 'le': '\u2264', 'geq': '\u2265', 'ge': '\u2265',
    'to': '\u2192', 'rightarrow': '\u2192', 'leftarrow': '\u2190',
    'Rightarrow': '\u21d2', 'in': '\u2208', 'infty': '\u221e',
    'partial': '\u2202', 'pm': '\u00b1', 'sim': '\u223c', 'propto': '\u221d',
    'sqrt': '\u221a', 'mid': '|',
    # Function names LaTeX sets upright
    'ln': 'ln', 'log': 'log', 'exp': 'exp', 'max': 'max', 'min': 'min',
    # Spacing / delimiter-sizing commands with no glyph of their own
    'quad': ' ', 'qquad': '  ', 'left': '', 'right': '',
}

# Accent commands -> Unicode combining mark (placed AFTER the base character).
_LATEX_ACCENTS = {'hat': '\u0302', 'bar': '\u0304', 'tilde': '\u0303'}


def _unicode_script(body, convert, marker):
    """Convert every char of a braced sub/superscript, else return ``marker + body``."""
    converted = [convert(ch) for ch in body]
    # convert() signals "no Unicode form" by returning marker + char.
    if body and not any(c.startswith(marker) for c in converted):
        return ''.join(converted)
    return marker + body


def _latex_to_unicode(math_text):
    """Approximate a LaTeX math fragment with plain Unicode text."""
    s = math_text
    # Structural markup with no visible output of its own.
    s = re.sub(r'\\(?:begin|end)\{[^{}]*\}', '', s)
    s = s.replace('\\\\', ' ').replace('&', '')
    # Text/font wrappers: keep only their content (\text{Flow} -> Flow).
    s = re.sub(r'\\(?:text|textrm|textit|textbf|mathrm|mathit|mathbf|mathbb|'
               r'mathcal|operatorname)\{([^{}]*)\}', r'\1', s)
    # Named symbols, looked up by whole command name; unknown ones are kept.
    s = re.sub(r'\\([A-Za-z]+)', lambda m: _LATEX_SYMBOLS.get(m.group(1), m.group(0)), s)
    # Thin/medium/negative spaces.
    s = re.sub(r'\\([,;:! ])', lambda m: '' if m.group(1) == '!' else ' ', s)
    # Accents run after symbol lookup, so \hat{\rho} -> rho + U+0302. Unicode
    # combining marks attach to the PRECEDING character, hence base first.
    s = re.sub(r'\\(hat|bar|tilde)\{([^{}]*)\}',
               lambda m: m.group(2) + _LATEX_ACCENTS[m.group(1)], s)
    s = s.replace('\\hat', '\u0302')
    # \frac{a}{b} -> a/b
    s = s.replace('\\frac{', '').replace('}{', '/')
    # Braced sub/superscripts: convert the whole group when every char has a
    # Unicode form (_{ijt} -> all three subscripted), else fall back below.
    s = re.sub(r'_\{([^{}]*)\}', lambda m: _unicode_script(m.group(1), subscript_char, '_'), s)
    s = re.sub(r'\^\{([^{}]*)\}', lambda m: _unicode_script(m.group(1), superscript_char, '^'), s)
    s = s.replace('{', '').replace('}', '')
    # Remaining single-character scripts (x_i, x^2).
    s = re.sub(r'_(\w)', lambda m: subscript_char(m.group(1)), s)
    s = re.sub(r'\^(\w)', lambda m: superscript_char(m.group(1)), s)
    return s


def add_rich_text(paragraph, text, base_size=11, base_italic=False):
    """Append inline-markdown *text* to *paragraph* as formatted runs.

    Recognised constructs are listed below. The earliest match in the text
    is handled first; on a tie, the construct listed first wins.

    * ``***x***``: bold italic.
    * ``**x**``: bold.
    * ``*x*``: italic.
    * ``$$x$$`` / ``$x$``: math. The LaTeX is approximated with Unicode and
      rendered italic.
    * ``[x]``: bracketed text; any ``@key`` inside loses its ``@``.
    * ``@key2020``: citation, rendered as the bare key.
    * ``---``: em dash.

    Everything else is plain text.

    Args:
        paragraph: python-docx paragraph (anything with ``add_run``).
        text: Markdown source for one paragraph.
        base_size: Font size in points applied to every run.
        base_italic: Italic setting for plain text, bold text, citations,
            brackets and dashes. Use it for passages that are italic as a
            whole, such as table notes.
    """
    remaining = text
    while remaining:
        match, kind = None, None
        for pattern, candidate in _INLINE_PATTERNS:
            found = pattern.search(remaining)
            # Strict '<' keeps the earlier list entry on a positional tie.
            if found and (match is None or found.start() < match.start()):
                match, kind = found, candidate

        if match is None:
            add_formatted_run(paragraph, remaining, size=base_size, italic=base_italic)
            break

        if match.start() > 0:
            add_formatted_run(paragraph, remaining[:match.start()], size=base_size, italic=base_italic)

        content = match.group(1) if match.groups() else None
        if kind == 'bold_italic':
            add_formatted_run(paragraph, content, bold=True, italic=True, size=base_size)
        elif kind == 'bold':
            add_formatted_run(paragraph, content, bold=True, italic=base_italic, size=base_size)
        elif kind == 'italic':
            add_formatted_run(paragraph, content, italic=True, size=base_size)
        elif kind in ('math', 'display_math'):
            add_formatted_run(paragraph, _latex_to_unicode(content), italic=True, size=base_size)
        elif kind == 'citation':
            add_formatted_run(paragraph, content, italic=base_italic, size=base_size)
        elif kind == 'bracket':
            # Bracketed citation lists like [e.g., @higgins1998]: drop the '@'.
            bracketed = re.sub(r'@(\w+)', r'\1', content)
            add_formatted_run(paragraph, f'[{bracketed}]', italic=base_italic, size=base_size)
        elif kind == 'emdash':
            add_formatted_run(paragraph, '\u2014', italic=base_italic, size=base_size)

        remaining = remaining[match.end():]


# Characters that have a Unicode subscript form. Letters come from the
# U+2090-U+209C block plus U+1D62-U+1D65 (i, r, u, v) and U+2C7C (j).
_SUBSCRIPTS = {
    '0': '\u2080', '1': '\u2081', '2': '\u2082', '3': '\u2083',
    '4': '\u2084', '5': '\u2085', '6': '\u2086', '7': '\u2087',
    '8': '\u2088', '9': '\u2089',
    '+': '\u208a', '-': '\u208b', '=': '\u208c', '(': '\u208d', ')': '\u208e',
    'a': '\u2090', 'e': '\u2091', 'o': '\u2092', 'x': '\u2093',
    'h': '\u2095', 'k': '\u2096', 'l': '\u2097', 'm': '\u2098',
    'n': '\u2099', 'p': '\u209a', 's': '\u209b', 't': '\u209c',
    'i': '\u1d62', 'r': '\u1d63', 'u': '\u1d64', 'v': '\u1d65',
    'j': '\u2c7c',
}


def subscript_char(c):
    """Return the Unicode subscript form of *c*, or ``'_' + c`` if none exists.

    The ``'_'`` prefix in the fallback keeps the LaTeX-like notation visible
    for characters Unicode cannot subscript (e.g. ``'c'`` -> ``'_c'``).
    """
    return _SUBSCRIPTS.get(c, '_' + c)


# Characters that have a Unicode superscript form (digits, signs,
# parentheses and the two Latin letters with dedicated code points).
_SUPERSCRIPTS = {
    '0': '\u2070', '1': '\u00b9', '2': '\u00b2', '3': '\u00b3',
    '4': '\u2074', '5': '\u2075', '6': '\u2076', '7': '\u2077',
    '8': '\u2078', '9': '\u2079',
    '+': '\u207a', '-': '\u207b', '=': '\u207c', '(': '\u207d', ')': '\u207e',
    'n': '\u207f', 'i': '\u2071',
}


def superscript_char(c):
    """Return the Unicode superscript form of *c*, or ``'^' + c`` if none exists.

    Signs and parentheses are included so a braced exponent such as
    ``^{-1}`` can be rendered entirely in superscript (``⁻¹``).
    """
    return _SUPERSCRIPTS.get(c, '^' + c)


# A GFM delimiter row: dashes with optional alignment colons per column.
_TABLE_SEPARATOR_RE = re.compile(r'^\|?\s*:?-+:?\s*(?:\|\s*:?-+:?\s*)*\|?$')
# Cell boundaries: pipes not escaped as '\|'.
_UNESCAPED_PIPE_RE = re.compile(r'(?<!\\)\|')


def parse_markdown_table(lines):
    """Parse pipe-table lines into a header row and data rows.

    ``lines[0]`` is the header. ``lines[1]`` is skipped when it is a
    delimiter row (``|---|:--:|``); otherwise it is kept as data rather than
    silently dropped. Blank lines are ignored. A literal pipe may appear in
    a cell as ``\\|``.

    Args:
        lines: Raw markdown lines of one table.

    Returns:
        ``(header, rows)``, where *header* is a list of cell strings and
        *rows* is a list of such lists. Returns ``(None, None)`` when fewer
        than two lines are given.
    """
    if len(lines) < 2:
        return None, None

    def parse_row(line):
        stripped = line.strip()
        cells = [cell.strip().replace('\\|', '|')
                 for cell in _UNESCAPED_PIPE_RE.split(stripped)]
        # Leading/trailing pipes produce an empty first/last fragment.
        if stripped.startswith('|'):
            cells = cells[1:]
        if stripped.endswith('|') and not stripped.endswith('\\|'):
            cells = cells[:-1]
        return cells

    header = parse_row(lines[0])
    body_start = 2 if _TABLE_SEPARATOR_RE.match(lines[1].strip()) else 1
    rows = [parse_row(line) for line in lines[body_start:] if line.strip()]
    return header, rows


def _shade_cell(cell, fill):
    """Give *cell* a solid background of hex colour *fill* (e.g. ``'D9E2F3'``)."""
    shading = parse_xml(f'<w:shd {nsdecls("w")} w:fill="{fill}"/>')
    cell._tc.get_or_add_tcPr().append(shading)


def add_table(doc, header, rows, title=None, notes=None):
    """Add a formatted table, with optional bold title and italic notes, to *doc*.

    Formatting applied:

    * Header cells are bold, 9 pt, centred and shaded light blue.
    * Data cells are 9 pt. The first column is left-aligned and the others
      are centred.
    * Every second data row is shaded grey across all columns.
    * Cells beyond ``len(header)`` in a row are ignored; rows shorter than
      the header leave the remaining cells empty.

    Args:
        doc: python-docx ``Document``.
        header: Header cell strings (inline markdown allowed).
        rows: Data rows, each a list of cell strings.
        title: Optional caption placed above the table.
        notes: Optional notes placed below the table, rendered italic.

    Returns:
        The created python-docx table.
    """
    if title:
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.LEFT
        # Spacing lives on paragraph_format; Paragraph has no space_* attrs.
        p.paragraph_format.space_before = Pt(12)
        add_formatted_run(p, title, bold=True, size=11)

    ncols = len(header)
    table = doc.add_table(rows=len(rows) + 1, cols=ncols)
    table.alignment = WD_TABLE_ALIGNMENT.CENTER
    table.style = 'Table Grid'
    # Let Word size the columns to their content.
    table.autofit = True

    # Header row
    for cell, cell_text in zip(table.rows[0].cells, header):
        p = cell.paragraphs[0]
        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
        add_rich_text(p, cell_text, base_size=9)
        for run in p.runs:
            run.bold = True
            run.font.size = Pt(9)
        _shade_cell(cell, 'D9E2F3')

    # Data rows
    for i, row_data in enumerate(rows):
        striped = i % 2 == 1
        for j, cell in enumerate(table.rows[i + 1].cells):
            if j < len(row_data):
                p = cell.paragraphs[0]
                p.alignment = WD_ALIGN_PARAGRAPH.LEFT if j == 0 else WD_ALIGN_PARAGRAPH.CENTER
                add_rich_text(p, row_data[j], base_size=9)
            # Shade every cell of a striped row, including ones a short row left empty.
            if striped:
                _shade_cell(cell, 'F2F2F2')

    if notes:
        p = doc.add_paragraph()
        p.paragraph_format.space_before = Pt(2)
        p.paragraph_format.space_after = Pt(12)
        add_rich_text(p, notes, base_size=8, base_italic=True)

    return table


# Child order required inside <w:tcBorders> (ECMA-376 CT_TcBorders); 'left'
# and 'right' are the transitional spellings of 'start' and 'end'.
_TC_BORDER_ORDER = ('top', 'start', 'left', 'bottom', 'end', 'right',
                    'insideH', 'insideV', 'tl2br', 'tr2bl')

# <w:tcPr> children that the schema places after <w:tcBorders>.
_TC_PR_AFTER_BORDERS = ('w:shd', 'w:noWrap', 'w:tcMar', 'w:textDirection',
                        'w:tcFitText', 'w:vAlign', 'w:hideMark', 'w:headers',
                        'w:cellIns', 'w:cellDel', 'w:cellMerge', 'w:tcPrChange')


def set_cell_border(cell, **kwargs):
    """Set one or more borders on a table cell.

    Each keyword names a border edge (``top``, ``bottom``, ``left``/``start``,
    ``right``/``end``, ``insideH``, ``insideV``, ``tl2br``, ``tr2bl``). Its
    value is a dict with these optional keys:

    * ``val``: default ``"single"``.
    * ``sz``: default ``"4"``.
    * ``color``: default ``"000000"``.

    Example: ``set_cell_border(cell, top={"sz": 12}, bottom={"val": "double"})``.

    Repeated calls on the same cell update the single existing
    ``<w:tcBorders>`` element; an edge that is set again replaces its
    previous definition. Edges are kept in schema order.
    """
    tcPr = cell._tc.get_or_add_tcPr()
    borders = tcPr.find(qn('w:tcBorders'))
    if borders is None:
        borders = parse_xml(f'<w:tcBorders {nsdecls("w")}></w:tcBorders>')
        # Plain append would land after any <w:shd> (e.g. from add_table's
        # shading), which violates the tcPr element order.
        tcPr.insert_element_before(borders, *_TC_PR_AFTER_BORDERS)
    for edge, val in kwargs.items():
        previous = borders.find(qn(f'w:{edge}'))
        if previous is not None:
            borders.remove(previous)
        borders.append(parse_xml(
            f'<w:{edge} {nsdecls("w")} w:val="{val.get("val", "single")}" '
            f'w:sz="{val.get("sz", "4")}" w:space="0" w:color="{val.get("color", "000000")}"/>'
        ))
    # Re-append children in schema order (append moves an existing element).
    rank = {name: pos for pos, name in enumerate(_TC_BORDER_ORDER)}
    for child in sorted(borders, key=lambda el: rank.get(el.tag.rsplit('}', 1)[-1], len(rank))):
        borders.append(child)


_HEADING_RE = re.compile(r'^(#{1,3})\s+(.*)')
_BULLET_RE = re.compile(r'^(\s*)-\s+(.*)')
# Heading level -> (font size, space before, space after), all in points.
_HEADING_STYLE_SPECS = {1: (16, 24, 12), 2: (13, 18, 8), 3: (11, 12, 6)}
_FONT_NAME = 'Times New Roman'
# Pandoc-style equation attributes, e.g. ``$$ ... $$ {#eq:gravity}``.
_EQUATION_ATTR_RE = re.compile(r'\$\$\s*\{#[^{}]*\}')


def _configure_document(doc):
    """Apply page margins and the Normal / Heading 1-3 styles used by the paper."""
    for section in doc.sections:
        section.top_margin = Cm(2.54)
        section.bottom_margin = Cm(2.54)
        section.left_margin = Cm(2.54)
        section.right_margin = Cm(2.54)

    normal = doc.styles['Normal']
    normal.font.name = _FONT_NAME
    normal.font.size = Pt(11)
    normal.paragraph_format.space_after = Pt(6)
    normal.paragraph_format.line_spacing = 1.15

    for level, (size, before, after) in _HEADING_STYLE_SPECS.items():
        style = doc.styles[f'Heading {level}']
        style.font.name = _FONT_NAME
        style.font.color.rgb = RGBColor(0, 0, 0)
        style.font.size = Pt(size)
        style.paragraph_format.space_before = Pt(before)
        style.paragraph_format.space_after = Pt(after)
        # The default template's heading styles reference theme fonts
        # (w:asciiTheme / w:hAnsiTheme), which take precedence over the
        # explicit w:ascii / w:hAnsi that font.name writes. Drop them so the
        # font set above is actually used.
        rpr = style.element.rPr
        rfonts = rpr.rFonts if rpr is not None else None
        if rfonts is not None:
            for attr in ('w:asciiTheme', 'w:hAnsiTheme'):
                if qn(attr) in rfonts.attrib:
                    del rfonts.attrib[qn(attr)]


def _add_centered_line(doc, text, size, space_after, bold=False, italic=False):
    """Add a centred single-run paragraph (title-page helper)."""
    p = doc.add_paragraph()
    p.alignment = WD_ALIGN_PARAGRAPH.CENTER
    p.paragraph_format.space_after = Pt(space_after)
    add_formatted_run(p, text, bold=bold, italic=italic, size=size)


def _add_labelled_line(doc, label, value, space_after):
    """Add a 10 pt paragraph made of a bold *label* followed by *value*."""
    p = doc.add_paragraph()
    p.paragraph_format.space_after = Pt(space_after)
    add_formatted_run(p, label, bold=True, size=10)
    add_formatted_run(p, value, size=10)


def _add_front_matter(doc, meta):
    """Render title, author, date, version, abstract, keywords and JEL from *meta*."""
    _add_centered_line(doc, meta.get('title', 'Untitled'), size=18, space_after=6, bold=True)
    _add_centered_line(doc, meta.get('author', ''), size=12, space_after=4)
    _add_centered_line(doc, meta.get('date', ''), size=12, space_after=4)
    if 'version' in meta:
        _add_centered_line(doc, f"Version: {meta['version']}", size=10, space_after=18, italic=True)

    if 'abstract' in meta:
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.LEFT
        p.paragraph_format.space_after = Pt(6)
        add_formatted_run(p, 'Abstract', bold=True, size=11)

        p = doc.add_paragraph()
        fmt = p.paragraph_format
        fmt.left_indent = Cm(1.27)
        fmt.right_indent = Cm(1.27)
        fmt.space_after = Pt(12)
        add_rich_text(p, meta['abstract'], base_size=10)

    if 'keywords' in meta:
        _add_labelled_line(doc, 'Keywords: ', meta['keywords'], space_after=2)
    if 'jel' in meta:
        _add_labelled_line(doc, 'JEL Classification: ', meta['jel'], space_after=18)


def _skip_blank_lines(lines, i):
    """Return the index of the first non-blank line at or after *i*."""
    while i < len(lines) and not lines[i].strip():
        i += 1
    return i


def _collect_table_lines(lines, i):
    """Return ``(table_lines, next_index)`` for the run of '|' lines starting at *i*."""
    start = i
    while i < len(lines) and '|' in lines[i]:
        i += 1
    return lines[start:i], i


def _consume_equation(doc, lines, i):
    """Render the ``$$`` display equation starting at *lines[i]*; return the next index."""
    first = lines[i].strip()
    eq_lines = [first]
    # The equation spans several lines unless the opening line already holds
    # both delimiters (which also covers "$$ ... $$ {#eq:label}").
    if first == '$$' or first.count('$$') < 2:
        i += 1
        while i < len(lines):
            eq_lines.append(lines[i].strip())
            if '$$' in lines[i]:
                break
            i += 1
    eq_text = _EQUATION_ATTR_RE.sub('$$', ' '.join(eq_lines))
    eq_text = eq_text.replace('$$', '').strip()
    if eq_text:
        p = doc.add_paragraph()
        p.alignment = WD_ALIGN_PARAGRAPH.CENTER
        p.paragraph_format.space_before = Pt(6)
        p.paragraph_format.space_after = Pt(6)
        # Re-wrap in $$ so add_rich_text converts the LaTeX instead of
        # printing it verbatim.
        add_rich_text(p, f'$${eq_text}$$', base_size=11, base_italic=True)
    return i + 1


def _consume_titled_table(doc, lines, i):
    """Render a ``**Table ...**`` title at *lines[i]*, its table and optional notes.

    Returns the index of the first unconsumed line, or ``None`` when no
    table follows the title (the caller then treats the line as prose).
    """
    title = lines[i].strip().strip('*')
    j = _skip_blank_lines(lines, i + 1)
    table_lines, j = _collect_table_lines(lines, j)
    if not table_lines:
        return None

    j = _skip_blank_lines(lines, j)
    notes = None
    if j < len(lines) and lines[j].strip().startswith('*Notes:'):
        parts = [lines[j].strip()]
        j += 1
        # Notes may wrap over several lines; stop at anything structural.
        while j < len(lines):
            stripped = lines[j].strip()
            if not stripped or stripped.startswith(('#', '**', '|', '*Notes:')):
                break
            parts.append(stripped)
            j += 1
        # Strip the italic markers after joining, so the closing '*' on the
        # last wrapped line is removed too.
        notes = ' '.join(parts).strip('*')

    header, rows = parse_markdown_table(table_lines)
    if header:
        add_table(doc, header, rows, title=title, notes=notes)
    return j


def _consume_table(doc, lines, i):
    """Render an untitled pipe table starting at *lines[i]*; return the next index."""
    table_lines, i = _collect_table_lines(lines, i)
    header, rows = parse_markdown_table(table_lines)
    if header:
        add_table(doc, header, rows)
    return i


def _consume_paragraph(doc, lines, i):
    """Render the (possibly wrapped) paragraph starting at *lines[i]*; return the next index."""
    para_lines = [lines[i]]
    i += 1
    while i < len(lines):
        next_line = lines[i]
        stripped = next_line.strip()
        # Stop at blank lines, headings, tables, table titles, equations, bullets.
        if (not stripped
                or _HEADING_RE.match(next_line)
                or stripped.startswith(('|', '**Table', '$$'))
                or _BULLET_RE.match(next_line)):
            break
        para_lines.append(next_line)
        i += 1

    para_text = ' '.join(line.strip() for line in para_lines)
    if para_text.strip():
        p = doc.add_paragraph()
        add_rich_text(p, para_text, base_size=11)
    return i


def _add_body(doc, body):
    """Render the markdown *body* (headings, equations, tables, bullets, paragraphs)."""
    lines = body.split('\n')
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.strip()
        if not stripped:
            i += 1
            continue

        heading = _HEADING_RE.match(line)
        if heading:
            level = len(heading.group(1))
            h = doc.add_heading(level=level)
            add_rich_text(h, heading.group(2).strip(), base_size=_HEADING_STYLE_SPECS[level][0])
            i += 1
            continue

        if stripped.startswith('$$'):
            i = _consume_equation(doc, lines, i)
            continue

        if stripped.startswith('**Table'):
            next_i = _consume_titled_table(doc, lines, i)
            if next_i is not None:
                i = next_i
                continue
            # No table follows (e.g. "**Table 3** reports ..."): this is prose,
            # so fall through and render it instead of dropping it.

        if stripped.startswith('|'):
            i = _consume_table(doc, lines, i)
            continue

        bullet = _BULLET_RE.match(line)
        if bullet:
            p = doc.add_paragraph(style='List Bullet')
            add_rich_text(p, bullet.group(2), base_size=11)
            i += 1
            continue

        i = _consume_paragraph(doc, lines, i)


def convert_paper():
    """Convert ``paper.md`` next to this script into ``gravity_paper.docx``.

    Writes the document into the same directory, then prints the output
    path and the number of tables rendered.
    """
    paper_dir = Path(__file__).parent
    md_path = paper_dir / 'paper.md'
    docx_path = paper_dir / 'gravity_paper.docx'

    meta, body = parse_yaml_front_matter(md_path.read_text(encoding='utf-8'))

    doc = Document()
    _configure_document(doc)
    _add_front_matter(doc, meta)
    doc.add_page_break()  # body starts on the page after the abstract
    _add_body(doc, body)

    doc.save(str(docx_path))
    print(f"Saved: {docx_path}")
    print(f"Tables: {len(doc.tables)} (all rendered inline)")


def _main():
    """Script entry point: build the Word document from paper.md."""
    convert_paper()


if __name__ == '__main__':
    _main()
